e3nn repository

@misc{mario_geiger_2019_3348277,
author = {Mario Geiger and
Tess Smidt and
Wouter Boomsma and
Maurice Weiler and
Michał Tyszkiewicz and
Jes Frellsen and
Benjamin K. Miller and
Josh Rackers},
title = {e3nn/e3nn: Point cloud support},
month = jul,
year = 2019,
doi = {10.5281/zenodo.3348277},
url = {https://doi.org/10.5281/zenodo.3348277}
}
SphericalTensor class like we did in data_types.ipynb.¶%load_ext autoreload
%autoreload 2
import torch
from spherical import SphericalTensor
torch.set_default_dtype(torch.float64)
Rs = [(1, 1)]
sum_Ls = sum((2 * L + 1) for mult, L in Rs for _ in range(mult))
signal_1 = torch.zeros(sum_Ls)
signal_1[0] = 1. # y
signal_2 = torch.zeros(sum_Ls)
signal_2[2] = 1. # x
sphten_1 = SphericalTensor(signal_1, Rs)
sphten_2 = SphericalTensor(signal_2, Rs)
import plotly
from plotly.subplots import make_subplots

# Default value forwarded as the `n` argument of SphericalTensor.plot
# (presumably the sampling resolution of the plotted surface -- confirm).
n = 50


def plot_operation(input1, input2, output, resolution=None):
    """Plot two operands and the result of an operation side by side.

    Each tensor is rendered with ``sphten.plot(relu=False, n=...)`` into its
    own 3D subplot, left to right: ``input1``, ``input2``, ``output``.

    Args:
        input1: first operand (a SphericalTensor).
        input2: second operand (a SphericalTensor).
        output: result of the operation (a SphericalTensor).
        resolution: optional override for the ``n`` passed to ``plot``;
            when None, the module-level ``n`` is used (read at call time,
            as before).

    Returns:
        A plotly figure with one row of 3D scenes, one per tensor.
    """
    if resolution is None:
        resolution = n
    tensors = [input1, input2, output]
    specs = [[{'is_3d': True} for _ in tensors]]
    fig = make_subplots(rows=1, cols=len(tensors), specs=specs)
    for col, sphten in enumerate(tensors, start=1):
        trace = sphten.plot(relu=False, n=resolution)
        # Hide the per-trace color scale (presumably so three colorbars
        # don't stack on top of each other).
        trace.showscale = False
        fig.add_trace(trace, row=1, col=col)
    # `update_layout(scene_aspectmode=...)` only reaches the first scene;
    # update_scenes applies the aspect mode to every 3D subplot.
    fig.update_scenes(aspectmode='data')
    return fig
# Sum of the two signals; the three panels show functions proportional to
# y, x, and (x + y).
new_sphten = sphten_1 + sphten_2
fig = plot_operation(sphten_1, sphten_2, new_sphten)
fig.show()

# `*` between SphericalTensors is used as a dot product here. The first pair
# (y with x) are orthogonal functions; the second pair (y with y) are
# identical. `dot_product` is left bound to the last result, as before.
for _lhs, _rhs in ((sphten_1, sphten_2), (sphten_1, sphten_1)):
    dot_product = _lhs * _rhs
    print(dot_product)
# `@` combines the two L=1 tensors; the Rs printed below show which orders
# the result contains.
new_sphten = sphten_1 @ sphten_2
print("input1 Rs", sphten_1.Rs)
print("input2 Rs", sphten_2.Rs)
print("output Rs", new_sphten.Rs)
print("")
# plots functions proportional to y, x, and a combination of z and xy.
# There is no constant term: y and x are orthogonal, so their scalar
# (dot-product) part is zero.
print("Now we have contributions to z (cross product) and xy (outer product).")
fig = plot_operation(sphten_1, sphten_2, new_sphten)
# Legend for the printed coefficients, one label per output component.
print("SH:", " 1 y z x xy yz * zx %",)
print("new", new_sphten.signal.numpy().round(3))
print("* == 2z^2 - x^2 - y^2")
print("% == x^2 - y^2")
fig.show()